import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys
from matplotlib.patches import Patch
import matplotlib.lines as mlines
sys.path.insert(1,'/REDACTED/fairness/code/config')
from configureColors import *
sys.path.insert(1,'/REDACTED/fairness/code/utilities')
import weightedBinscatter as wb
import estimators as et
#sys.path.append('packages')
#import binscatter
import matplotlib.font_manager as fm
import matplotlib
import statsmodels.formula.api as smf
from statsmodels.iolib.smpickle import load_pickle
import joblib
from trajectoryPlotters import *

# Plot defaults: apply the project's matplotlib style sheet, then register the bundled
# cmunrm.ttf font under the family name 'latex' (inserted first in the font list) and make
# it the default font family.
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')
fe = fm.FontEntry(fname='/REDACTED/fairness/code/config/fonts/cmunrm.ttf', name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name

# Color/alpha pairs from the project color config:
#   cb/ab, cnb/anb -> Black / non-Black;  ce/ae, cne/ane -> 'eitc' / 'non' groups.
colorsBNB = cdictBlackNonBlack()
colorsENE = cdictEITCNon()
cb, ab = colorsBNB['Black']['color'], colorsBNB['Black']['alpha']
cnb, anb = colorsBNB['non-Black']['color'], colorsBNB['non-Black']['alpha']
ce, ae = colorsENE['eitc']['color'], colorsENE['eitc']['alpha']
cne, ane = colorsENE['non']['color'], colorsENE['non']['alpha']

## Linear Estimator
def regEstimate(dataset, pbvarb='predicted_prob_black', outcome='audited'):
    """Estimate the Black/non-Black gap in ``outcome`` by OLS of ``outcome`` on ``pbvarb``.

    Fits ``outcome ~ pbvarb`` with HC1 heteroskedasticity-robust standard errors.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain the columns named by ``pbvarb`` and ``outcome``.
    pbvarb : str
        Column holding the predicted probability of being Black (the regressor).
    outcome : str
        Column holding the outcome being modeled.

    Returns
    -------
    tuple
        ``(coef, se, black_aud_rate, non_black_aud_rate)``: the slope on ``pbvarb``, its HC1
        standard error, the fitted outcome at ``pbvarb == 1`` (intercept + slope) and the
        fitted outcome at ``pbvarb == 0`` (intercept).
    """
    # `smf` (statsmodels.formula.api) is what this file imports; `smf.ols` is the formula
    # constructor for OLS. The bare name `sm` is not imported here.
    model = smf.ols(outcome + ' ~ ' + pbvarb, data=dataset).fit(cov_type='HC1')
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]
    # Look coefficients up by label: positional `params[0]` on a label-indexed Series is
    # deprecated in pandas and will stop falling back to positions.
    non_black_aud_rate = model.params['Intercept']
    black_aud_rate = non_black_aud_rate + coef
    return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator
def chenEstimate(dataset, pbvarb='predicted_prob_black', outcome='audited'):
    """Estimate group outcome rates by weighting each row with its race probability.

    The Black rate is the mean of ``outcome`` weighted by ``dataset[pbvarb]``; the non-Black
    rate is the mean weighted by ``1 - dataset[pbvarb]``.

    Returns
    -------
    tuple
        ``(est, black_aud_rate, non_black_aud_rate)`` where
        ``est = black_aud_rate - non_black_aud_rate``.
    """
    p_black = dataset[pbvarb]
    p_other = 1 - p_black
    y = dataset[outcome]
    black_aud_rate = (p_black * y).sum() / p_black.sum()
    non_black_aud_rate = (p_other * y).sum() / p_other.sum()
    return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset, pbvarb='predicted_prob_black', outcome='audited', seReg=0.0):
    """Return ``(seReg, seChen)``: the linear-estimator SE and the probabilistic-estimator SE.

    ``seChen`` is ``seReg`` scaled by ``getSEMultiplier(dataset, pbvarb)``. ``outcome`` is
    accepted for signature compatibility but is not used. The caller is expected to pass
    ``seReg`` (e.g. the SE returned by ``regEstimate``); with the default of 0.0 the
    returned SEs are zero (for a finite multiplier).
    """
    scale = getSEMultiplier(dataset, pbvarb)
    return seReg, seReg * scale

def getSEMultiplier(dataset, pbvarb='predicted_prob_black'):
    """Return ``sqrt(Var(p) / (mean(p) * (1 - mean(p))))`` for ``p = dataset[pbvarb]``.

    ``Var`` is pandas' sample variance (default ``ddof=1``). ``getSEs`` multiplies the
    linear-estimator SE by this factor to obtain the probabilistic-estimator SE.
    """
    probs = dataset[pbvarb]
    p_bar = probs.mean()
    return np.sqrt(probs.var() / (p_bar * (1 - p_bar)))

### OUTPUT

# NOTE(review): `out` is not referenced by the code below; the plots are saved to literal paths.
out = '/REDACTED/'

### OP AUDIT
# Load the BISG dataset and keep only rows that have a predicted probability of being Black.
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final.csv")
dataBISG = dataBISG.dropna(subset=['predicted_prob_black'])

#########################
#### Fig 4: Audit Rates by Income and Race
#########################

# Filter to taxpayers with non-negative AGI and split them into 100 AGI quantile bins,
# labelled 1..100. `.copy()` makes this an independent frame, so adding the bin column
# below does not raise SettingWithCopyWarning against `dataBISG`.
dataBISG_noneg_agi = dataBISG[dataBISG['adj_gross_inc'] >= 0].copy()
dataBISG_noneg_agi["bin_100"] = pd.qcut(
    dataBISG_noneg_agi["adj_gross_inc"], q=100, labels=list(range(1, 100 + 1)))
# observed=False keeps one entry per bin label (NaN mean for an empty bin) regardless of the
# pandas version's default, so this series has exactly one row per bin, in label order.
bin_100_mean = dataBISG_noneg_agi.groupby("bin_100", observed=False)["adj_gross_inc"].mean()

lin_ar_100 = []
lin_ar_b_100 = []
lin_ar_nb_100 = []
lin_ar_se_100 = []
prob_ar_100 = []
prob_ar_b_100 = []
prob_ar_nb_100 = []
prob_ar_se_100 = []
# Loop over every bin label, in the same order as bin_100_mean, and compute the outcomes of interest.
for bin_label in bin_100_mean.index:
    temp_100 = dataBISG_noneg_agi[dataBISG_noneg_agi["bin_100"] == bin_label]
    if len(temp_100) == 0:
        # Record NaN for an empty bin rather than skipping it: skipping would shift every later
        # estimate onto the wrong bin mean when these lists are zipped with bin_100_mean below.
        for results in (lin_ar_100, lin_ar_b_100, lin_ar_nb_100, lin_ar_se_100,
                        prob_ar_100, prob_ar_b_100, prob_ar_nb_100, prob_ar_se_100):
            results.append(np.nan)
        continue

    # linear (OLS) estimate of the audit-rate gap and the two group audit rates
    lin_aud_bin_100, lin_aud_se_100, lin_black_aud_bin_100, lin_non_black_aud_bin_100 = regEstimate(
        temp_100, pbvarb="predicted_prob_black", outcome="aud_no_research_audits")
    lin_ar_100.append(lin_aud_bin_100)
    lin_ar_b_100.append(lin_black_aud_bin_100)
    lin_ar_nb_100.append(lin_non_black_aud_bin_100)
    lin_ar_se_100.append(lin_aud_se_100)

    # probabilistic estimate; its SE is the OLS SE rescaled by getSEs
    prob_aud_bin_100, prob_black_aud_bin_100, prob_non_black_aud_bin_100 = chenEstimate(
        temp_100, pbvarb="predicted_prob_black", outcome="aud_no_research_audits")
    prob_aud_se_100 = getSEs(temp_100, pbvarb="predicted_prob_black",
                             outcome="aud_no_research_audits", seReg=lin_aud_se_100)[1]
    prob_ar_100.append(prob_aud_bin_100)
    prob_ar_b_100.append(prob_black_aud_bin_100)
    prob_ar_nb_100.append(prob_non_black_aud_bin_100)
    prob_ar_se_100.append(prob_aud_se_100)

# compile results in probabilistic and linear dataframes (one row per bin, in bin-label order)
lin_disp_df = pd.DataFrame(list(zip(lin_ar_100, lin_ar_b_100, lin_ar_nb_100, lin_ar_se_100, bin_100_mean)), columns = ["audit_rate","audit_rate_black","audit_rate_non_black","audit_rate_se", "bin_mean"])
prob_disp_df = pd.DataFrame(list(zip(prob_ar_100, prob_ar_b_100, prob_ar_nb_100, prob_ar_se_100, bin_100_mean)), columns = ["audit_rate","audit_rate_black","audit_rate_non_black","audit_rate_se", "bin_mean"])


def _plot_audit_rates_by_income(disp_df, path):
    """Plot Black vs. non-Black audit rates (in percent) by income bin and save the figure.

    Mutates ``disp_df`` in place: multiplies 'audit_rate_black' and 'audit_rate_non_black'
    by 100 and adds a 'bin' column numbering the rows 1..len(disp_df), used as the x-axis.

    Parameters
    ----------
    disp_df : pandas.DataFrame
        One row per income bin, in bin order, with 'audit_rate_black' and
        'audit_rate_non_black' columns (presumably proportions, given the x100 rescale).
    path : str
        Output file passed to ``plt.savefig``.
    """
    disp_df['audit_rate_black'] = disp_df['audit_rate_black'] * 100
    disp_df['audit_rate_non_black'] = disp_df['audit_rate_non_black'] * 100
    # Derive the x-axis from the row count instead of assuming exactly 100 rows.
    disp_df['bin'] = list(range(1, len(disp_df) + 1))

    ax = plt.subplot(111)
    # Same hue for both groups; the non-Black series is distinguished by lower alpha.
    ax.plot(disp_df.bin, disp_df.audit_rate_black, '-o', color='purple', alpha=1,
            markersize=2, label="Black")
    ax.plot(disp_df.bin, disp_df.audit_rate_non_black, '-o', color='purple', alpha=0.4,
            markersize=2, label="Non-Black")
    ax.set_ylabel('Audit Rate (%)')
    ax.set_xlabel('Income Percentile')
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles=handles)
    plt.savefig(path)
    plt.close()


### Left panel, probabilistic
_plot_audit_rates_by_income(prob_disp_df, '/REDACTED/prob_disp_inc_plot_100_full.png')

### Right panel, linear
_plot_audit_rates_by_income(lin_disp_df, '/REDACTED/lin_disp_inc_plot_100_full.png')